import sys,os
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__),'..'))
sys.path.append(project_root)
sys.path.append(r'../witin_nn/nn')

import torch
import torch.optim as optim
from matplotlib import pyplot as plt
from gzymodel import ResNet18
from train_fun import load_dataset, train_runner, test_runner, fig_plot


# Training script: fits a float (quant=False) ResNet18, evaluates it on the
# test loader after every epoch, and plots the loss/accuracy histories.

# Select the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on {'GPU' if torch.cuda.is_available() else 'CPU'}")
# Ten class labels. Not referenced anywhere else in this script
# (presumably the CIFAR-10 label set -- TODO confirm against load_dataset).
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
model = ResNet18(quant=False).to(device)

trainloader, testloader = load_dataset()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Number of training epochs. Kept in its own name so the loop variable
# `epoch` below does not overwrite the configured count.
num_epochs = 20
# Per-epoch histories, consumed by fig_plot() once training is finished.
Loss = []
Accuracy = []
test_Loss = []
test_Accuracy = []

# Ensure the checkpoint directory exists. This does not depend on the epoch,
# so it is done once up front rather than on every iteration.
# NOTE(review): this is a relative path, so it is resolved against the current
# working directory rather than against `project_root` -- confirm the script is
# always launched from its own directory.
directory = '../model_params/witin/float'
os.makedirs(directory, exist_ok=True)

for epoch in range(num_epochs):
    # One checkpoint file name per epoch, passed through to train_runner.
    file_path = os.path.join(directory, f'ResNet18_param_{epoch}.pth')
    # qat/nat are both disabled for this float run; presumably these toggle
    # quantization-/noise-aware training inside train_runner -- TODO confirm.
    loss, acc = train_runner(model, device, trainloader, optimizer, file_path,
                             qat=False, nat=False, save_interval=1000)
    print("\nepoch: ", epoch)
    print("train: loss, acc", loss, acc)
    Loss.append(loss)
    Accuracy.append(acc)

    test_loss, test_acc = test_runner(model, device, testloader)
    test_Loss.append(test_loss)
    test_Accuracy.append(test_acc)
    print("test : loss, acc", test_loss, test_acc)

print('Finished Training')
fig_plot(Loss, test_Loss, Accuracy, test_Accuracy)